import os

from function import *
from load_data import *
import matplotlib.pyplot as plt
# Experiment configuration (compares several learning rates on one network shape):
#   size       -- network structure: number of neurons in each layer, from the
#                 input layer to the output layer (here 784 -> 12 -> 10).
#   epochs     -- number of training epochs.
#   batch_size -- number of samples drawn for each training step (mini-batch).
#   alpha      -- list of learning rates to compare; one network is trained per value.
#   file_save  -- base name of the saved loss-curve figure: img/<file_save>.png.
size = [784, 12, 10]
epochs = 15
batch_size = 10
alpha = [0.01, 0.001, 0.005]
file_save = '1'

file_train = 'experiment_05_training_set.csv'
file_test = 'experiment_05_testing_set.csv'
train_data = load_data(file_train, size)
test_data = load_data(file_test, size)

# Train one network per learning rate and overlay the loss curves on one figure.
for a in alpha:
    # A new FCN instance per learning rate, so runs don't share trained state.
    NN = FCN(size)
    print(a)
    NN.SGD(train_data=train_data, epochs=epochs, batch_size=batch_size, alpha=a, test_data=test_data)
    # NOTE(review): assumes SGD fills NN.list_loss with a sequence of loss values
    # (per epoch or per batch?) -- confirm against FCN before labelling the x-axis.
    plt.plot(NN.list_loss, label=str(a))

plt.legend(title='alpha')
# Build the path from file_save directly. The former f'img/{0}.png'.format(...)
# evaluated {0} inside the f-string first and always produced 'img/0.png'.
os.makedirs('img', exist_ok=True)  # savefig raises if the directory is missing
plt.savefig(os.path.join('img', f'{file_save}.png'))
plt.show()
